# -*- coding:utf-8 -*-

import numpy as np
import tensorflow as tf
from config.glob.global_pool import global_pool


class DatasetOpDto:
    """Holds dataset parameters and exposes them together with the
    init_iterator / next_batch operations on top of a TF1 Session.

    Typical flow: set ``num`` and ``batch_size``, call ``build_dataset``,
    then ``build_iterator``, then ``init_iterator`` / ``next_batch`` per epoch.
    """

    def __init__(self, sess, handle=None):
        """
        :param sess: tf.Session used to run the iterator init/next ops.
        :param handle: optional callable(dataset) -> dataset applied as a
            custom transform before batching.
        """
        self.__sess = sess
        self.__handle = handle   # custom dataset-transform hook
        self.__num = 0           # number of samples in the dataset
        self.__batch_size = 0    # samples per batch
        self.__dataset = None    # tf.data.Dataset instance
        self.__init_op = None    # iterator initializer op
        self.__next_op = None    # next-batch op

    @property
    def num(self):
        return self.__num

    @num.setter
    def num(self, num):
        self.__num = num

    @property
    def batch_size(self):
        return self.__batch_size

    @batch_size.setter
    def batch_size(self, batch_size):
        self.__batch_size = batch_size

    @property
    def dataset(self):
        return self.__dataset

    @property
    def batch_num(self):
        """Number of full batches: num // batch_size.

        NOTE: floor division, so the trailing partial batch emitted by
        ``Dataset.batch`` (drop_remainder defaults to False) is not counted.

        :raises ValueError: if batch_size is unset (0) or either value is
            negative.
        """
        # Validate BEFORE dividing: the original returned a meaningless
        # negative count when num < 0 while batch_size > 0.
        if self.__batch_size < 0 or self.__num < 0:
            raise ValueError('训练集样本数和batch_size必须为正数')
        if self.__batch_size == 0:
            raise ValueError('batch_size需要初始化')
        return self.__num // self.__batch_size

    def build_dataset(self, features, labels, is_train=True):
        """Build a tf.data.Dataset from in-memory arrays.

        :param features: array-like features; cast to int64 when
            global_pool.config.xs_dtype == 'int', else float32.
        :param labels: array-like labels; cast per ys_dtype likewise.
        :param is_train: True -> shuffle then batch by batch_size;
            False -> emit the whole set as one batch of size num.
        """
        xs_dtype = np.int64 if global_pool.config.xs_dtype == 'int' else np.float32
        ys_dtype = np.int64 if global_pool.config.ys_dtype == 'int' else np.float32

        self.__dataset = tf.data.Dataset.from_tensor_slices(
            (features.astype(xs_dtype), labels.astype(ys_dtype))
        )
        if is_train:  # training set
            # Buffer size == num gives a uniform full shuffle; assumes num
            # has been set to a positive sample count before this call.
            self.__dataset = self.__dataset.shuffle(self.__num)
            if self.__handle:
                self.__dataset = self.__handle(self.__dataset)
            self.__dataset = self.__dataset.batch(self.__batch_size)
            # NOTE(review): prefetch() counts *batches*, not samples, so
            # prefetching batch_size batches buffers far more than the
            # conventional 1/AUTOTUNE — kept as-is to preserve behaviour;
            # confirm the intent before tuning.
            self.__dataset = self.__dataset.prefetch(self.__batch_size)
        else:  # validation / test set: a single batch holding everything
            if self.__handle:
                self.__dataset = self.__handle(self.__dataset)
            self.__dataset = self.__dataset.batch(self.__num)

    def build_iterator(self):
        """Create a reinitializable iterator over the built dataset and
        cache its initializer and get_next ops for later Session runs."""
        iterator = tf.data.Iterator.from_structure(
            self.__dataset.output_types, self.__dataset.output_shapes
        )
        self.__init_op = iterator.make_initializer(self.__dataset)
        self.__next_op = iterator.get_next()

    def init_iterator(self):
        """Run the iterator initializer op (restart the dataset)."""
        self.__sess.run(self.__init_op)

    # Backward-compatible alias: the method was originally published under
    # this misspelled name, so existing callers must keep working.
    init_itreator = init_iterator

    def next_batch(self):
        """Run the next-batch op.

        :return: (x, y) tuple for the next batch as produced by sess.run.
        """
        x, y = self.__sess.run(self.__next_op)
        return x, y
